import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

from cbml_benchmark.utils.pmath import dist_matrix

from cbml_benchmark.losses.registry import LOSS


@LOSS.register('hyp_loss')
class HYPLoss(nn.Module):
    """Contrastive cross-entropy loss over hyperbolic or dot-product similarities.

    The batch is reshaped to ``(num_groups, NUM_INSTANCES, dim)``. For every
    ordered pair of distinct instance slots ``(i, j)``, row ``r`` of slot ``i``
    must identify row ``r`` of slot ``j`` (its positive) among all other rows of
    both slots (its negatives).

    See also https://github.com/kanshichao/CBML (hyp + MVC variant).
    """

    def __init__(self, cfg, tau=0.2, hyp_c=0.1, hyper_weight=0.5):
        """
        Args:
            cfg: config node; ``cfg.DATA.NUM_INSTANCES`` is read in ``forward``.
            tau: temperature dividing the similarity logits.
            hyp_c: curvature passed to ``dist_matrix``; ``0`` switches to plain
                dot-product similarity ("sphere mode" -- presumably the inputs
                are L2-normalised in that case; confirm with the caller).
            hyper_weight: stored on the module but not used by ``forward``.
        """
        super(HYPLoss, self).__init__()

        self.tau = tau
        self.hyp_c = hyp_c
        self.cfg = cfg
        self.hyper_weight = hyper_weight

    def _similarity(self, a, b):
        """Return the similarity matrix between the rows of ``a`` and ``b``.

        Uses a dot product when ``hyp_c == 0``. Otherwise it uses the negated
        ``dist_matrix`` distance, so that larger always means more similar.
        """
        if self.hyp_c == 0:
            return a @ b.t()
        return -dist_matrix(a, b, c=self.hyp_c)

    def _pair_loss(self, x0, x1, target, eye_mask):
        """Cross-entropy of matching each row of ``x0`` to the same row of ``x1``.

        The logits are ``[sim(x0, x1) | sim(x0, x0)] / tau``. ``eye_mask``
        pushes the self-similarities on the diagonal of the second half to
        about -1e9, so a sample can never select itself. The other rows of
        both halves then act as negatives.
        """
        logits00 = self._similarity(x0, x0) / self.tau - eye_mask
        logits01 = self._similarity(x0, x1) / self.tau
        logits = torch.cat([logits01, logits00], dim=1)
        # Subtract the (detached) per-row maximum for numerical stability.
        logits = logits - logits.max(1, keepdim=True)[0].detach()
        return F.cross_entropy(logits, target)

    def forward(self, x, dim, labels=None):
        """Compute the loss for a batch of embeddings.

        Args:
            x: embeddings whose first dimension is the batch size and which hold
                ``batch_size * dim`` elements in total. The reshape below
                assumes that consecutive runs of ``NUM_INSTANCES`` rows share a
                class. NOTE(review): confirm the sampler guarantees this order.
            dim: embedding dimension.
            labels: unused; accepted for interface compatibility.

        Returns:
            The sum of the per-pair cross-entropy losses divided by the batch
            size.

        Raises:
            ValueError: if the batch size is not a multiple of
                ``cfg.DATA.NUM_INSTANCES``.
        """
        batch_size = x.size(0)
        num_instances = self.cfg.DATA.NUM_INSTANCES
        if batch_size % num_instances != 0:
            raise ValueError(
                f"batch size {batch_size} is not a multiple of "
                f"cfg.DATA.NUM_INSTANCES={num_instances}"
            )
        num_groups = batch_size // num_instances

        # z[g, k] is the k-th instance of the g-th group of consecutive rows.
        # reshape (unlike view) also accepts non-contiguous inputs.
        z = x.reshape(num_groups, num_instances, dim)

        # Loop-invariant tensors, built on the input's device instead of a
        # hard-coded .cuda() so the loss also runs on CPU or other devices.
        target = torch.arange(num_groups, device=x.device)
        eye_mask = torch.eye(num_groups, device=x.device) * 1e9

        losses = [
            self._pair_loss(z[:, i], z[:, j], target, eye_mask)
            for i in range(num_instances)
            for j in range(num_instances)
            if i != j
        ]
        # NOTE(review): F.cross_entropy already averages over the num_groups
        # rows. Summing the NUM_INSTANCES * (NUM_INSTANCES - 1) pair losses and
        # dividing by batch_size therefore scales the mean by
        # (NUM_INSTANCES - 1) / num_groups; it does not average them. Kept
        # as-is to preserve the training objective's scale -- confirm intent.
        return sum(losses) / batch_size

